{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s8n4-kD8pWKl"
      },
      "source": [
        "<span style=\"color: red;\">Requirement when running in Google Colab</span>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pZ77N4JMpWKn"
      },
      "outputs": [],
      "source": [
        "!pip install diffusers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bq8GMLpLpWKo"
      },
      "source": [
        "#  Chapter 6 - Real-world image reconstruction\n",
        "\n",
        "Having achieved our goal of applying face-related changes to synthetic (generated) images, we would like to do the same for any arbitrary image. In this chapter we therefore take a step back from modifying attention layers and focus on whether there is a combination of prompt and latents that, when fed to the model, leads to the generation of a given image. The approach is taken from the paper Null-text Inversion for Editing Real Images using Guided Diffusion Models (https://arxiv.org/abs/2211.09794), which reconstructs a real image precisely and efficiently through the sampling process of Stable Diffusion with a relevant prompt."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QiP6ys6xpWKp"
      },
      "source": [
        "As before, this part has been copied from the previous chapter; you can simply run it and move on to the next cell."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f-SXtKMTpWKp"
      },
      "outputs": [],
      "source": [
        "import warnings\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, DDIMScheduler\n",
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "from typing import Optional\n",
        "from tqdm import tqdm\n",
        "\n",
        "\n",
        "model_id = \"stabilityai/stable-diffusion-2-1-base\"\n",
        "\n",
        "pipe = StableDiffusionPipeline.from_pretrained(model_id)\n",
        "pipe = pipe.to(\"cuda\")\n",
        "\n",
        "prompt = \"A photo of a woman, straight hair, light blonde and pink hair, smiling expression, grey background\"\n",
        "\n",
        "prompt_embeds = pipe.encode_prompt(prompt=prompt, negative_prompt=\"\", device=\"cuda\", num_images_per_prompt=1, do_classifier_free_guidance=True)\n",
        "\n",
        "cond_prompt_embeds = prompt_embeds[0]\n",
        "uncond_prompt_embeds = prompt_embeds[1]\n",
        "\n",
        "prompt_embeds_combined = torch.cat([uncond_prompt_embeds, cond_prompt_embeds])\n",
        "\n",
        "initial_latents = torch.randn((1, pipe.unet.config.in_channels, pipe.unet.config.sample_size, pipe.unet.config.sample_size), generator=torch.Generator().manual_seed(22)).to(\"cuda\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lGx9GQPdpWKp"
      },
      "source": [
        "Before we can start the process, we first need to load a real-world image and convert it from pixel space to latent space using the VAE."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hLIBw5t6pWKq"
      },
      "outputs": [],
      "source": [
        "import torchvision\n",
        "from PIL import Image\n",
        "import requests\n",
        "\n",
        "url1 = 'https://raw.githubusercontent.com/OutofAi/StableFace/main/photo.png'\n",
        "filename1 = url1.split('/')[-1]\n",
        "response1 = requests.get(url1)\n",
        "with open(filename1, 'wb') as f:\n",
        "    f.write(response1.content)\n",
        "\n",
        "\n",
        "img = Image.open('photo.png')\n",
        "\n",
        "transform = torchvision.transforms.Compose([\n",
        "    torchvision.transforms.Resize((512, 512)),\n",
        "    torchvision.transforms.ToTensor()\n",
        "])\n",
        "\n",
        "loaded_image = transform(img).to(\"cuda\").unsqueeze(0)\n",
        "\n",
        "if loaded_image.shape[1] == 4:\n",
        "    loaded_image = loaded_image[:,:3,:,:]\n",
        "    \n",
        "with torch.no_grad():\n",
        "    encoded_image = pipe.vae.encode(loaded_image*2 - 1)\n",
        "    real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7HRspPd1pWKq"
      },
      "source": [
        "The idea is quite simple. Because the DDIM scheduler is deterministic, its steps can be run in either direction: removing noise and adding it back follow (approximately) the same path. As the U-Net only predicts the noise, we can apply an inverse scheduler instead of the regular scheduler to run the sampling backward, going from a real image to its corresponding latent noise. Later, we can use that latent noise together with a relevant prompt to reconstruct the image with Stable Diffusion.\n",
        "\n",
        "So we load the inverse scheduler alongside the regular one, using the same model configuration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QlDVtMzapWKq"
      },
      "outputs": [],
      "source": [
        "inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder=\"scheduler\")\n",
        "scheduler = DDIMScheduler.from_pretrained(model_id, subfolder=\"scheduler\")\n"
      ]
    },
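    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of why running the scheduler backward works, the deterministic DDIM update (with $\\eta = 0$) can be written as below. The notation is an assumption borrowed from the DDIM and null-text inversion papers rather than from this notebook: $z_t$ is the latent at timestep $t$, $\\bar{\\alpha}_t$ is the cumulative noise schedule and $\\epsilon_\\theta$ is the U-Net noise prediction for the prompt embedding $C$.\n",
        "\n",
        "$$z_{t-1} = \\sqrt{\\bar{\\alpha}_{t-1}}\\,\\frac{z_t - \\sqrt{1-\\bar{\\alpha}_t}\\,\\epsilon_\\theta(z_t, t, C)}{\\sqrt{\\bar{\\alpha}_t}} + \\sqrt{1-\\bar{\\alpha}_{t-1}}\\,\\epsilon_\\theta(z_t, t, C)$$\n",
        "\n",
        "Inversion applies the same formula in the opposite direction, under the assumption that the noise prediction changes little between neighbouring timesteps:\n",
        "\n",
        "$$z_{t+1} = \\sqrt{\\bar{\\alpha}_{t+1}}\\,\\frac{z_t - \\sqrt{1-\\bar{\\alpha}_t}\\,\\epsilon_\\theta(z_t, t, C)}{\\sqrt{\\bar{\\alpha}_t}} + \\sqrt{1-\\bar{\\alpha}_{t+1}}\\,\\epsilon_\\theta(z_t, t, C)$$\n",
        "\n",
        "Because that assumption only holds approximately, each inverse step introduces a small error, so the round trip is close to, but not exactly, the identity."
      ]
    },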
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SK5uvqHYpWKr"
      },
      "source": [
        "During the inversion process we won't be using classifier-free guidance (CFG), so technically we don't need the unconditional prompt embedding. For simplicity we keep it and set `guidance_scale` to 1, which effectively disables CFG."
      ]
    },
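    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The guidance step used in the sampling loops of this notebook can be summarised in one line. This is a sketch with symbols chosen here for illustration: $w$ stands for `guidance_scale`, $C$ for the conditional prompt embedding, $\\varnothing$ for the unconditional (null-text) embedding and $\\epsilon_\\theta$ for the U-Net noise prediction.\n",
        "\n",
        "$$\\tilde{\\epsilon}_\\theta(z_t, t, C, \\varnothing) = \\epsilon_\\theta(z_t, t, \\varnothing) + w\\,\\big(\\epsilon_\\theta(z_t, t, C) - \\epsilon_\\theta(z_t, t, \\varnothing)\\big)$$\n",
        "\n",
        "Setting $w = 1$ cancels the unconditional term and leaves $\\epsilon_\\theta(z_t, t, C)$, which is why a guidance scale of 1 is equivalent to not using CFG."
      ]
    },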
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UYU1pzqQpWKr"
      },
      "outputs": [],
      "source": [
        "num_inference_steps = 10\n",
        "\n",
        "# note: CFG is disabled here by setting the guidance scale to 1\n",
        "guidance_scale = 1\n",
        "inverse_scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = inverse_scheduler.timesteps\n",
        "\n",
        "latents = real_image_latents\n",
        "\n",
        "inversed_latents = []\n",
        "\n",
        "with torch.no_grad():\n",
        "\n",
        "    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=\"Inference steps\"):\n",
        "\n",
        "        inversed_latents.append(latents)\n",
        "\n",
        "        latent_model_input = torch.cat([latents] * 2)\n",
        "\n",
        "        noise_pred = pipe.unet(\n",
        "            latent_model_input,\n",
        "            t,\n",
        "            encoder_hidden_states=prompt_embeds_combined,\n",
        "            cross_attention_kwargs=None,\n",
        "            return_dict=False,\n",
        "        )[0]\n",
        "\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
        "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "\n",
        "        # using inverse_scheduler instead of scheduler\n",
        "        latents = inverse_scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n",
        "\n",
        "\n",
        "# initial state\n",
        "real_image_initial_latents = latents"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jdVqogTFpWKs"
      },
      "source": [
        "Let's display the latents produced at each step of the inverse process"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IFyph2GWpWKs"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "def display_latents(latents):\n",
        "    with torch.no_grad():\n",
        "        num_latents = len(latents)\n",
        "        images_np = []\n",
        "\n",
        "        for latent in latents:\n",
        "            image = pipe.vae.decode(latent / pipe.vae.config.scaling_factor, return_dict=False)[0]\n",
        "            image_np = image.squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()\n",
        "            image_np = (image_np - image_np.min()) / (image_np.max() - image_np.min())\n",
        "            images_np.append(image_np)\n",
        "\n",
        "        # Calculate the figure size based on the number of latents\n",
        "        fig_width = min(20, 2 * num_latents)  # Max width of 20, 2 inches per image\n",
        "        fig_height = 2  # Fixed height for all images\n",
        "\n",
        "        fig, axes = plt.subplots(1, num_latents, figsize=(fig_width, fig_height))\n",
        "\n",
        "        if num_latents == 1:\n",
        "            axes = [axes]  # Ensure axes is always iterable\n",
        "\n",
        "        for i, (ax, image_np) in enumerate(zip(axes, images_np)):\n",
        "            ax.imshow(image_np)\n",
        "            ax.axis('off')\n",
        "            ax.set_title(f'Latent {i}')\n",
        "\n",
        "        plt.tight_layout()\n",
        "        plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RBTWHU8epWKt"
      },
      "outputs": [],
      "source": [
        "display_latents(inversed_latents)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m83ZzYw2pWKt"
      },
      "source": [
        "Now, theoretically, we should be able to use the initial latents to regenerate the real-world image through the Stable Diffusion sampling process"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N7W_oS1FpWKt"
      },
      "outputs": [],
      "source": [
        "guidance_scale = 7.5\n",
        "scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = scheduler.timesteps\n",
        "\n",
        "latents = real_image_initial_latents\n",
        "\n",
        "with torch.no_grad():\n",
        "\n",
        "  for i, t in tqdm(enumerate(timesteps)):\n",
        "\n",
        "    latent_model_input = torch.cat([latents] * 2)\n",
        "    noise_pred = pipe.unet(\n",
        "        latent_model_input,\n",
        "        t,\n",
        "        encoder_hidden_states=prompt_embeds_combined,\n",
        "        cross_attention_kwargs=None,\n",
        "        return_dict=False,\n",
        "    )[0]\n",
        "\n",
        "    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
        "    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "\n",
        "    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n",
        "\n",
        "  image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]\n",
        "  image_np = image.squeeze(0).float().permute(1,2,0).detach().cpu()\n",
        "  image_np = image_np - image_np.min()\n",
        "  image_np = image_np / image_np.max()\n",
        "  plt.imshow(image_np)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4ZDqP7LapWKt"
      },
      "source": [
        "As you may have noticed, the results are less than satisfactory. They resemble the real-world image to a degree, but they don't capture all of its intricate details. The discrepancy mainly stems from CFG, which perturbs the sampling process and causes it to deviate from the trajectory of the real-world image. With that in mind, we can now make use of the unconditional prompt embedding to correct this deviation: at each step we look for a vector that moves the generation back toward the desired result. In the paper, this is done through pivotal inversion followed by null-text optimization.\n",
        "\n",
        "The (modified) figure below, taken from the original paper, shows the process: at each step, the mean squared error against the inverted latent is used to move the generation back to its original route. The null-text is our unconditional prompt (or negative prompt) embedding, which is what we are trying to replace here."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WE-PoSc-pWKu"
      },
      "source": [
        "![image.png](assets/image6.png)\n",
        "\n",
        "(Modified image; source: https://arxiv.org/abs/2211.09794)"
      ]
    },
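    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The per-step objective described by the figure can be sketched as follows. The notation is an assumption made for illustration: $z^{\\mathrm{inv}}_{t-1}$ is the latent recorded during inversion (`inversed_latents` in this notebook), $\\varnothing_t$ is the null-text embedding optimised for timestep $t$, $C$ is the conditional prompt embedding, $\\tilde{\\epsilon}_\\theta$ is the guided (CFG) noise prediction and $\\mathrm{DDIM}(\\cdot)$ stands for one scheduler step.\n",
        "\n",
        "$$\\min_{\\varnothing_t} \\left\\lVert z^{\\mathrm{inv}}_{t-1} - \\mathrm{DDIM}\\big(z_t, \\tilde{\\epsilon}_\\theta(z_t, t, C, \\varnothing_t)\\big) \\right\\rVert_2^2$$\n",
        "\n",
        "Only $\\varnothing_t$ is updated; the U-Net, the VAE and the conditional embedding $C$ stay frozen."
      ]
    },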
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gz2LYwUwpWKu"
      },
      "outputs": [],
      "source": [
        "import torch.nn as nn\n",
        "\n",
        "W_values = uncond_prompt_embeds.repeat(num_inference_steps, 1, 1)\n",
        "QT = nn.Parameter(W_values.clone())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IOTaONE0pWKu"
      },
      "outputs": [],
      "source": [
        "import torch.nn.functional as F\n",
        "import gc\n",
        "\n",
        "guidance_scale = 7.5\n",
        "scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = scheduler.timesteps\n",
        "\n",
        "optimizer = torch.optim.AdamW([QT], lr=0.008)\n",
        "\n",
        "pipe.vae.eval()\n",
        "pipe.vae.requires_grad_(False)\n",
        "pipe.unet.eval()\n",
        "pipe.unet.requires_grad_(False)\n",
        "\n",
        "last_loss = 1\n",
        "\n",
        "for epoch in range(50):\n",
        "    gc.collect()\n",
        "    torch.cuda.empty_cache()\n",
        "\n",
        "    intermediate_values = real_image_initial_latents.clone().requires_grad_(True)\n",
        "\n",
        "    if last_loss < 0.02:\n",
        "        break\n",
        "    elif last_loss < 0.03:\n",
        "        for param_group in optimizer.param_groups:\n",
        "            param_group['lr'] = 0.003\n",
        "    elif last_loss < 0.035:\n",
        "        for param_group in optimizer.param_groups:\n",
        "            param_group['lr'] = 0.006\n",
        "\n",
        "    for i in range(num_inference_steps):\n",
        "        latents = intermediate_values.detach().clone().requires_grad_(True)\n",
        "\n",
        "        t = timesteps[i]\n",
        "\n",
        "        prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds.detach()])\n",
        "\n",
        "        latent_model_input = torch.cat([latents] * 2)\n",
        "\n",
        "        noise_pred_model = pipe.unet(\n",
        "            latent_model_input,\n",
        "            t,\n",
        "            encoder_hidden_states=prompt_embeds,\n",
        "            cross_attention_kwargs=None,\n",
        "            return_dict=False,\n",
        "        )[0]\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text = noise_pred_model.chunk(2)\n",
        "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "\n",
        "        intermediate_values = scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n",
        "\n",
        "        loss = F.mse_loss(inversed_latents[len(timesteps) - 1 - i].detach(), intermediate_values, reduction=\"mean\")\n",
        "        last_loss = loss\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        print(f\"Loss (epoch {epoch} - Step {i}): {loss.item()}\")\n",
        "    print(f\"Reconstruction Loss (epoch {epoch}): {last_loss.item()}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PySfRvpLpWKu"
      },
      "outputs": [],
      "source": [
        "guidance_scale = 7.5\n",
        "scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = scheduler.timesteps\n",
        "\n",
        "previous_latents = None\n",
        "\n",
        "with torch.no_grad():\n",
        "    gc.collect()\n",
        "    torch.cuda.empty_cache()\n",
        "    intermediate_values = real_image_initial_latents.clone().requires_grad_(True)\n",
        "\n",
        "    for i, t in enumerate(timesteps):\n",
        "        latents_value = intermediate_values.detach().clone().requires_grad_(True)\n",
        "\n",
        "\n",
        "        prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds.detach()])\n",
        "\n",
        "        latent_model_input = torch.cat([latents_value] * 2)\n",
        "\n",
        "        # Predict the noise residual\n",
        "        noise_pred_model = pipe.unet(\n",
        "            latent_model_input,\n",
        "            t,\n",
        "            encoder_hidden_states=prompt_embeds,\n",
        "            cross_attention_kwargs=None,\n",
        "            return_dict=False,\n",
        "        )[0]\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text = noise_pred_model.chunk(2)\n",
        "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "\n",
        "        intermediate_values = scheduler.step(noise_pred, t, latents_value, return_dict=False)[0]\n",
        "\n",
        "\n",
        "    image = pipe.vae.decode(intermediate_values / pipe.vae.config.scaling_factor, return_dict=False)[0]\n",
        "    image_np = image.squeeze(0).float().permute(1,2,0).detach().cpu()\n",
        "    image_np = image_np - image_np.min()\n",
        "    image_np = image_np / image_np.max()\n",
        "    plt.imshow(image_np)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YcBs_TWFpWKv"
      },
      "source": [
        "# Save Training Data\n",
        "If you want to avoid re-running the training process in the next chapter, you can uncomment the next cell to save the relevant data, and then skip the Training section of the next chapter."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g4FOB6dfpWKv"
      },
      "outputs": [],
      "source": [
        "# combined_data = {\n",
        "#     'initial_latent': real_image_initial_latents,\n",
        "#     'QT': QT\n",
        "# }\n",
        "# torch.save(combined_data, \"reconstruction_data.pt\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4xKoaoQspWKw"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "L4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
